HDU 5785 Interesting(Manacher | 回文树)
题意:
给定长度 $N\le 10^6$ 的字符串 $s$，寻找所有三元组 $(i, j, k)$，$1\le i\le j<k\le N$，
使得 $s[i\ldots j]$ 和 $s[j+1\ldots k]$ 都是回文串，对所有这样的三元组求 $\sum i\times k \bmod (10^9+7)$。
分析:
比赛时的做法比较笨，是用回文树硬做的。设 $a$、$b$ 分别为两个相邻回文串的长度，
记这两个回文串为 $[i\ldots j-1]$ 和 $[j\ldots k]$（这里 $j$ 表示第二个回文串的起点，于是 $i=j-a$，$k=j+b-1$）：
$\sum\sum i\times k=\sum\sum (j-1-a+1)(j+b-1)$
$=\sum\sum (j-a)(j+b-1)$
$=\sum_j\Big[preCnt_{j-1}\cdot sufCnt_j\cdot j^2$
$+\big(preCnt_{j-1}\cdot(sufSum_j-sufCnt_j)-preSum_{j-1}\cdot sufCnt_j\big)\,j$
$-preSum_{j-1}\cdot(sufSum_j-sufCnt_j)\Big]$
其中 $preCnt[i]$ := 以 $i$ 结尾的回文串个数，$preSum[i]$ := 以 $i$ 结尾的回文串的长度和；
$sufCnt[i]$、$sufSum[i]$ 同理，表示以 $i$ 开头的回文串的个数与长度和。
然后正着、反着各建一次回文树预处理一下就做完了，时间复杂度 $O(n)$（字符集大小视为常数）。
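另一种做法：对固定的分界点 $j$，左边回文串（以 $j$ 结尾）与右边回文串（以 $j+1$ 开头）的选取相互独立，所以 $\sum\sum i\times k=\left(\sum i\right)\times\left(\sum k\right)$。
用 Manacher 的话，直接预处理 $sum[0][i]$ := 以 $i$ 开头的回文串的右端点之和，$sum[1][i]$ := 以 $i$ 结尾的回文串的左端点之和。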
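对于一个以 $i$ 为中心、延伸距离为 $p[i]$ 的最长回文串，
显然 $l=i-(p[i]-1)$，$r=i+p[i]-1$。
对于 $sum[0][x]$，$x\in[l, i]$：以 $x$ 开头、以 $i$ 为中心的回文串右端点为 $2i-x$，即依次为 $r, r-1, \ldots, i$，
这是一个首项为 $r$、公差为 $-1$ 的等差数列（$sum[1]$ 在 $[i, r]$ 上同理，首项为 $i$、公差为 $-1$）。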
由于所有更新都是离线的（先全部更新、再统一求值），我们可以用差分 + 前缀和来处理。
对于区间 $[L, R]$ 上首项为 $a$、公差为 $d$ 的等差数列更新，
用 $delta$ 数组记录公差的差分，$sum$ 数组记录结果的差分：

    delta[L+1] += d              // [L+1, R] 区间内有公差
    delta[R+1] -= d
    sum[L]     += a
    sum[R+1]   -= a + (R-L)*d    // 这里要去掉公差累计的影响
求前缀和时，先对 $delta$ 求前缀和得到每个位置的公差，再把它累加进 $sum$ 即可。
直接在 Manacher 变换后的数组上做就行：原串第 $k$ 个字符对应变换串下标 $2k$，端点和都被放大了 2 倍，最后除以 2（乘 2 的逆元）就好。
然后就做完了。
代码一（回文树）:
//
// Created by TaoSama on 2016-08-02
// Copyright (c) 2016 TaoSama. All rights reserved.
//
#pragma comment(linker, "/STACK:102400000,102400000")
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <set>
#include <vector>
using namespace std;
// Debug printing helpers (only referenced from commented-out lines).
#define pr(x) cout << #x << " = " << x << " "
#define prln(x) cout << #x << " = " << x << endl
// Array bound, an "infinity" value (not used in this listing), and the modulus.
const int N = 1000010;
const int INF = 0x3f3f3f3f;
const int MOD = 1000000007;
// Palindromic tree (eertree) over the alphabet 0..S-1.
// Node 0 is the even root (len 0), node 1 is the odd root (len -1).
// For a non-root node v, cnt[v] is the number of non-root nodes on its fail
// chain (the number of palindromic suffixes of that palindrome) and sum[v] is
// their total length mod MOD. add() therefore returns the count and the total
// length of all palindromes ending at the character just appended.
struct PalindromicTree {
    static const int M = 1e6 + 10, S = 26;
    int n, sz, last;  // chars added, nodes allocated, node of the longest palindromic suffix
    int nxt[M][S], fail[M], len[M];
    char s[M];        // s[1..n] = added characters, s[0] = -1 sentinel
    int cnt[M], sum[M];

    // Allocates a fresh node of length l with no outgoing edges.
    int newnode(int l) {
        len[sz] = l;
        sum[sz] = cnt[sz] = 0;
        memset(nxt[sz], 0, sizeof(nxt[sz]));
        return sz++;
    }

    // Resets the tree to the two roots; call before feeding a new string.
    void init() {
        sz = last = 0;
        newnode(0); newnode(-1);
        s[n = 0] = -1;
        fail[0] = 1;
        // Set explicitly instead of relying on static zero-initialization.
        // getfail() always stops at node 1, because len -1 compares s[n] with itself.
        fail[1] = 0;
    }

    // Follows fail links from u to the longest palindrome that s[n] can extend.
    int getfail(int u) {
        while(s[n - len[u] - 1] != s[n]) u = fail[u];
        return u;
    }

    // Appends character c (0 <= c < S) and returns {cnt, sum} of the
    // palindromes ending at the new position.
    pair<int, int> add(int c) {
        s[++n] = c;
        int u = getfail(last);
        // The reference stays valid: newnode() only writes row sz, never row u.
        int& v = nxt[u][c];
        if(!v) {
            int cur = newnode(len[u] + 2);
            // Compute fail before linking v. For u == 1 the lookup nxt[1][c]
            // must still see 0 (the even root), not cur itself.
            fail[cur] = nxt[getfail(fail[u])][c];
            v = cur;
            cnt[v] = cnt[fail[v]] + 1;
            sum[v] = sum[fail[v]] + len[v];
            if(sum[v] >= MOD) sum[v] -= MOD;
        }
        last = v;
        return {cnt[v], sum[v]};
    }
} pt;
// Input string (1-based) and its length.
int n;
char s[N];
// preCnt[i] / preSum[i]: number / total length (mod MOD) of palindromes ending at i.
int preCnt[N];
int preSum[N];
using LL = long long;
// Modular product. Negative inputs give the (possibly negative) C++ remainder;
// the caller normalizes the final answer.
LL mul(LL x, LL y) {
    const LL product = x * y;
    return product % MOD;
}
// Reads strings until EOF. For each string, prints the sum of i * k over all
// triples where s[i..j] and s[j+1..k] are both palindromes, mod MOD.
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\TaoSama\\Desktop\\in.txt", "r", stdin);
//  freopen("C:\\Users\\TaoSama\\Desktop\\out.txt","w",stdout);
#endif
    ios_base::sync_with_stdio(0);
    while(scanf("%s", s + 1) == 1) {
        n = strlen(s + 1);
        // Forward pass: count and total length of palindromes ending at each i.
        pt.init();
        for(int i = 1; i <= n; ++i) {
            auto ret = pt.add(s[i] - 'a');
            tie(preCnt[i], preSum[i]) = ret;
        }
        // Backward pass: feeding s[n], s[n-1], ... yields the palindromes that
        // start at i. Combine them with the palindromes ending at i - 1.
        pt.init();
        LL ans = 0;
        for(int i = n; i > 1; --i) {
            auto ret = pt.add(s[i] - 'a');
            int sufCnt, sufSum;
            tie(sufCnt, sufSum) = ret;
            // Sum of (i - a)(i + b - 1) over palindromes [i-a, i-1] and [i, i+b-1],
            // split into the i^2, i^1 and i^0 coefficients.
            LL sqI = mul(mul(mul(i, i), preCnt[i - 1]), sufCnt);
            LL mid = mul(preCnt[i - 1], sufSum - sufCnt) -
                     mul(sufCnt, preSum[i - 1]);
            mid %= MOD;
            mid = mul(mid, i);
            LL rht = mul(sufSum - sufCnt, preSum[i - 1]);
            // Terms can be negative (sufSum - sufCnt may be < 0 after the mod).
            // ans stays in (-MOD, MOD) and is normalized after the loop.
            ans += sqI + mid - rht;
            ans %= MOD;
        }
        ans = (ans + MOD) % MOD;
        // ans is now in [0, MOD), so it fits in an int. %I64d is MSVC-specific:
        // glibc parses it as flag 'I' + width 64 + %d, which mismatches long long.
        printf("%d\n", static_cast<int>(ans));
    }
    return 0;
}
代码二（Manacher）:
//
// Created by TaoSama on 2016-08-03
// Copyright (c) 2016 TaoSama. All rights reserved.
//
#pragma comment(linker, "/STACK:102400000,102400000")
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <set>
#include <vector>
using namespace std;
// Debug printing helpers (not used in this listing).
#define pr(x) cout << #x << " = " << x << " "
#define prln(x) cout << #x << " = " << x << endl
// Array bound, an "infinity" value (not used in this listing), and the modulus.
const int N = 1000010;
const int INF = 0x3f3f3f3f;
const int MOD = 1000000007;
// Manacher transform, illustrated. The example omits the leading '@' sentinel
// that Manacher::init stores at s[0], so code indices are shifted by one:
//   original    a[i]: w a a b w s w f d
//   transformed s[i]: # w # a # a # b # w # s # w # f # d #
//   helper      p[i]: 1 2 1 2 3 2 1 2 1 2 1 4 1 2 1 2 1 2 1
// p[i]     := how far the palindrome centred at s[i] extends to the right in
//             the transformed string, plus 1 (for s[i] itself)
// p[i] - 1 := length of the corresponding palindrome in the original string
char s[N];  // input string, 1-based
int n;      // its length
// Manacher's algorithm on the '#'-interleaved string, plus difference arrays
// that accumulate, for each transformed position x (mod MOD):
//   sum[0][x] = sum of right endpoints of palindromes starting at x
//   sum[1][x] = sum of left endpoints of palindromes ending at x
// Both sums are in transformed coordinates: original index k <-> position 2k.
struct Manacher {
    static const int M = N << 1;
    char s[M];
    int n, p[M];
    int delta[2][M], sum[2][M];

    // Builds "@#a1#a2#...#an#" from the 1-based, NUL-terminated string a.
    // '@' at s[0] and the terminating 0 at s[n] are distinct sentinels.
    void init(char* a) {
        s[0] = '@'; s[1] = '#'; n = 2;
        for(int i = 1; a[i]; ++i)
            s[n++] = a[i], s[n++] = '#';
        s[n] = 0;
    }

    // Fills p[1..n-1] and returns the maximum p[i] - 1, i.e. the length of the
    // longest palindrome in the original string.
    int gao() {
        int mx = 0, id = 0, ret = 0;  // id is only read once mx > i; initialized anyway
        for(int i = 1; i < n; ++i) {
            p[i] = mx > i ? min(mx - i, p[2 * id - i]) : 1;
            while(s[i - p[i]] == s[i + p[i]]) ++p[i];
            if(mx < i + p[i]) mx = i + p[i], id = i;
            ret = max(ret, p[i] - 1);
        }
        return ret;
    }

    // x = (x + y) mod MOD, for x in [0, MOD) and y in [-MOD, MOD).
    inline void add(int& x, int y) {
        if(y < 0) y += MOD;
        if((x += y) >= MOD) x -= MOD;
    }

    // Builds sum[0] (0 -> start) and sum[1] (1 -> end) from p; call after gao().
    void process() {
        // Only indices [0, n] are written here (i + 1 <= n and r + 1 <= n), so
        // clear just that prefix instead of all 4 * M ints on every test case.
        for(int j = 0; j < 2; ++j) {
            memset(delta[j], 0, sizeof(int) * (n + 1));
            memset(sum[j], 0, sizeof(int) * (n + 1));
        }
        for(int i = 1; i < n; ++i) {
            int r = i + p[i] - 1, l = i - (p[i] - 1);
            // sum[0][x] += 2i - x for x in [l, i]: first term r, step -1.
            add(sum[0][l], r);
            add(sum[0][i + 1], -r - (i - l) * (-1));
            add(delta[0][l + 1], -1);
            add(delta[0][i + 1], 1);
            // sum[1][x] += 2i - x for x in [i, r]: first term i, step -1.
            add(sum[1][i], i);
            add(sum[1][r + 1], -i - (r - i) * (-1));
            add(delta[1][i + 1], -1);
            add(delta[1][r + 1], 1);
        }
        // Prefix sums: delta becomes the per-position step, which is folded into sum.
        for(int i = 1; i < n; ++i) {
            for(int j = 0; j < 2; ++j) {
                add(delta[j][i], delta[j][i - 1]);
                add(sum[j][i], sum[j][i - 1]);
                add(sum[j][i], delta[j][i]);
            }
        }
    }

    // Whether the original 1-based substring [l, r] is a palindrome (uses p).
    bool ok(int l, int r) {
        l <<= 1; r <<= 1;
        int k = l + r >> 1;
        return k + p[k] - 1 >= r;
    }
} ma;
// Binary exponentiation: returns x^n mod MOD (assumes 0 <= x < MOD, n >= 0).
// Not called by main in this listing.
int quick(int x, int n) {
    int ret = 1;
    for(; n ; n >>= 1) {
        if(n & 1) ret = 1LL * ret * x % MOD;
        x = 1LL * x * x % MOD;
    }
    return ret;
}
// Modular inverse of 2, derived from MOD rather than hard-coded (500000004):
// for odd MOD, 2 * ((MOD + 1) / 2) = MOD + 1, which is 1 mod MOD.
const int invTwo = (MOD + 1) / 2;
// Reads strings until EOF. For each split point j, multiplies the sum of left
// endpoints of palindromes ending at j by the sum of right endpoints of
// palindromes starting at j + 1, and prints the total mod MOD.
int main() {
#ifdef LOCAL
    freopen("C:\\Users\\TaoSama\\Desktop\\in.txt", "r", stdin);
//  freopen("C:\\Users\\TaoSama\\Desktop\\out.txt","w",stdout);
#endif
    ios_base::sync_with_stdio(0);
    while(scanf("%s", s + 1) == 1) {
        ma.init(s);
        ma.gao();
        ma.process();
        n = strlen(s + 1);
        // Per transformed position, both sums are doubled:
        //   sum[0]: right-endpoint sums of palindromes starting there
        //   sum[1]: left-endpoint sums of palindromes ending there
        int* startRight = ma.sum[0];
        int* endLeft = ma.sum[1];
        int ans = 0;
        for(int j = 1; j < n; ++j) {
            // Split between original j and j + 1 (transformed 2j and 2j + 2).
            // Halve both sums in place to undo the coordinate doubling.
            const int lhsPos = 2 * j;
            const int rhsPos = 2 * (j + 1);
            endLeft[lhsPos] = 1LL * endLeft[lhsPos] * invTwo % MOD;
            startRight[rhsPos] = 1LL * startRight[rhsPos] * invTwo % MOD;
            ans = (ans + 1LL * endLeft[lhsPos] * startRight[rhsPos] % MOD) % MOD;
        }
        printf("%d\n", ans);
    }
    return 0;
}